from abc import ABC
from abc import abstractmethod
from gym import spaces
import numpy as np

class PbRobot(ABC):
    '''
    Base wrapper around a single URDF body loaded into a pybullet simulation.

    Subclasses customise how the body is loaded (``load_robot``), how its
    pose is read (``getPose``) and how it is controlled (``controller``).
    '''

    def __init__(self, bullet_client, physclient, urdf_path, start_pos, robot_loader=None):
        '''
        bullet_client: pybullet module / client wrapper used for every simulator call.
        physclient: physics client id forwarded as ``physicsClientId``.
        urdf_path: path of the URDF file to load.
        start_pos: initial base pose; the default loader reads ``[:3]`` as the
            position and ``[3:]`` as the orientation quaternion.
        robot_loader: optional callable ``(urdf_path, start_pos) -> body id``
            used instead of ``self.load_robot``.
        '''
        self.bullet_client = bullet_client
        self.phys_client = physclient

        loader = self.load_robot if robot_loader is None else robot_loader
        self.bodyId = loader(urdf_path, start_pos)

        # Snapshot of the pose right after loading (uses the subclass getPose).
        self.default_joint_states = self.getPose()
        # pybullet addresses the base with link index -1; subclasses override this.
        self.defaultLink = -1

    def load_robot(self, urdf_path, pose):
        '''
        Load ``urdf_path`` with its base at ``pose[:3]`` and orientation
        ``pose[3:]`` (quaternion); return the pybullet body id.
        '''
        return self.bullet_client.loadURDF(fileName=urdf_path,
                                           basePosition=pose[:3],
                                           baseOrientation=pose[3:],
                                           physicsClientId=self.phys_client)

    def getPose(self):
        '''
        Return the base ``(position, orientation)`` as reported by
        ``getBasePositionAndOrientation``. Subclasses may extend this with
        joint values.
        '''
        return self.bullet_client.getBasePositionAndOrientation(bodyUniqueId=self.bodyId,
                                                                physicsClientId=self.phys_client)

    def resetPose(self, translation, rotation):
        '''
        Teleport the base to ``translation`` with orientation ``rotation``
        (quaternion) using ``resetBasePositionAndOrientation``.
        '''
        self.bullet_client.resetBasePositionAndOrientation(bodyUniqueId=self.bodyId,
                                                           posObj=translation,
                                                           ornObj=rotation,
                                                           physicsClientId=self.phys_client)

    def controller(self, targets):
        '''
        Define a custom controller or pass-through to the default pybullet controller.

        Raises:
            NotImplementedError: the base class provides no controller.
        '''
        # Raise (not return) so a missing override fails loudly instead of
        # silently handing the exception class back to the caller.
        raise NotImplementedError

    # Interface methods that concrete robots are expected to provide:
    # @property
    # @abstractmethod
    # def actions(self):
    #     raise NotImplementedError

    # @abstractmethod
    # def getJointStates(self):
    #     raise NotImplementedError

    # @abstractmethod
    # def setJoints(self, targetPos):
    #     raise NotImplementedError

    # @abstractmethod
    # def resetJoints(self, joint_states):
    #     raise NotImplementedError

class Husky(PbRobot):
    '''
    Planar robot whose pose is described as ``[x, y, yaw]`` and applied
    directly to the base.

    FIXME: Won't work with the new API
    '''

    action_space = spaces.Box(
        low=np.array([-1, -1, -1]),
        high=np.array([1, 1, 1]),
    )

    def __init__(self, bullet_client, physclient, urdf_path, start_pos):
        # These must exist before PbRobot.__init__ loads the body.
        self.maxJointForce = 1.
        self.basejoints = [0, 1, 2]  # phantom joints
        self.additionalJoints = []
        self.JointIdx = [*self.basejoints, *self.additionalJoints]
        self.baseLink = 4
        super().__init__(bullet_client, physclient, urdf_path, start_pos,
                         robot_loader=self.load_robot)
        self.defaultLink = self.baseLink

    @property
    def actions(self):
        '''Number of joint indices in ``JointIdx``.'''
        return len(self.JointIdx)

    def load_robot(self, urdf_path, pose):
        '''
        Load the URDF at the world origin with identity orientation.

        ``pose`` is ignored; ``setPose`` moves the base afterwards.
        '''
        origin = [0., 0., 0.]
        identity_quat = [0., 0., 0., 1.]
        return self.bullet_client.loadURDF(fileName=urdf_path,
                                           basePosition=origin,
                                           baseOrientation=identity_quat,
                                           physicsClientId=self.phys_client)

    def getPose(self):
        '''Return ``[x, y, yaw]`` of the base; yaw is extracted from its quaternion.'''
        position, orientation = self.bullet_client.getBasePositionAndOrientation(
            bodyUniqueId=self.bodyId, physicsClientId=self.phys_client)
        yaw = self.bullet_client.getEulerFromQuaternion(orientation)[2]
        return [position[0], position[1], yaw]

    def setPose(self, pose):
        '''
        Place the base at ``(pose[0], pose[1], 0)`` with yaw ``pose[2]``.

        Returns ``pose`` unchanged.
        '''
        x, y, yaw = pose[0], pose[1], pose[2]
        quat = self.bullet_client.getQuaternionFromEuler([0, 0, yaw])
        self.bullet_client.resetBasePositionAndOrientation(bodyUniqueId=self.bodyId,
                                                           physicsClientId=self.phys_client,
                                                           posObj=[x, y, 0.0],
                                                           ornObj=quat)
        return pose

class LimoRobot(PbRobot):
    '''
    Planar robot whose pose is described as ``[x, y, yaw]`` and applied
    directly to the base; actions are small per-step pose increments.

    FIXME: Won't work with the new API
    '''

    action_space = spaces.Box(
        low=-np.array([0.02, 0.02, 0.0314]),
        high=np.array([0.02, 0.02, 0.0314]),
    )

    def __init__(self, bullet_client, physclient, urdf_path, start_pos):
        # These must exist before PbRobot.__init__ loads the body.
        self.maxJointForce = 1.
        self.basejoints = [0, 1, 2]  # phantom joints
        self.additionalJoints = []
        self.JointIdx = [*self.basejoints, *self.additionalJoints]
        self.baseLink = 4
        super().__init__(bullet_client, physclient, urdf_path, start_pos,
                         robot_loader=self.load_robot)
        self.defaultLink = self.baseLink

    @property
    def actions(self):
        '''Number of joint indices in ``JointIdx``.'''
        return len(self.JointIdx)

    def load_robot(self, urdf_path, pose):
        '''
        Load the URDF at the world origin with identity orientation.

        ``pose`` is ignored; ``setPose`` moves the base afterwards.
        '''
        origin = [0., 0., 0.]
        identity_quat = [0., 0., 0., 1.]
        return self.bullet_client.loadURDF(fileName=urdf_path,
                                           basePosition=origin,
                                           baseOrientation=identity_quat,
                                           physicsClientId=self.phys_client)

    def getPose(self):
        '''Return ``[x, y, yaw]`` of the base; yaw is extracted from its quaternion.'''
        position, orientation = self.bullet_client.getBasePositionAndOrientation(
            bodyUniqueId=self.bodyId, physicsClientId=self.phys_client)
        yaw = self.bullet_client.getEulerFromQuaternion(orientation)[2]
        return [position[0], position[1], yaw]

    def setPose(self, pose):
        '''
        Place the base at ``(pose[0], pose[1], 0)`` with yaw ``pose[2]``.

        Returns ``pose`` unchanged.
        '''
        x, y, yaw = pose[0], pose[1], pose[2]
        quat = self.bullet_client.getQuaternionFromEuler([0, 0, yaw])
        self.bullet_client.resetBasePositionAndOrientation(bodyUniqueId=self.bodyId,
                                                           physicsClientId=self.phys_client,
                                                           posObj=[x, y, 0.0],
                                                           ornObj=quat)
        return pose


class SlabRobo(PbRobot):
    '''
    Planar robot whose pose is described as ``[x, y, yaw]`` and applied
    directly to the base.
    '''

    action_space = spaces.Box(low=np.array([-1,-1,-1]),
                              high=np.array([1,1,1]))

    def __init__(self, bullet_client, physclient, urdf_path, start_pos):
        # These must exist before PbRobot.__init__ loads the body.
        self.maxJointForce = 1.
        self.basejoints = [0,1,2] # phantom joints
        self.additionalJoints = []
        self.JointIdx = self.basejoints + self.additionalJoints
        self.baseLink = 4
        super().__init__(bullet_client, physclient, urdf_path, start_pos, robot_loader=self.load_robot)
        self.defaultLink = self.baseLink

    @property
    def actions(self):
        '''Number of joint indices in ``JointIdx``.'''
        return len(self.JointIdx)

    def load_robot(self, urdf_path, pose):
        '''
        Load the URDF at the world origin with identity orientation.

        ``pose`` is ignored; ``setPose`` moves the base afterwards.
        '''
        return self.bullet_client.loadURDF(fileName=urdf_path,
                                           basePosition=[0.,0.,0.],
                                           baseOrientation=[0.,0.,0.,1.],
                                           physicsClientId=self.phys_client)

    def getPose(self):
        '''Return ``[x, y, yaw]`` of the base; yaw is extracted from its quaternion.'''
        pos, orn = self.bullet_client.getBasePositionAndOrientation(bodyUniqueId=self.bodyId,physicsClientId=self.phys_client)
        orn_euler = self.bullet_client.getEulerFromQuaternion(orn)
        return [pos[0],pos[1],orn_euler[2]]

    def setPose(self,pose):
        '''
        Place the base at ``(pose[0], pose[1], 0)`` with yaw ``pose[2]``.

        Returns ``pose`` unchanged.
        '''
        pos = [pose[0],pose[1],0.0]
        orn = [0,0,pose[2]]
        quat = self.bullet_client.getQuaternionFromEuler(orn)
        self.bullet_client.resetBasePositionAndOrientation(bodyUniqueId=self.bodyId,physicsClientId=self.phys_client,posObj=pos,ornObj=quat)
        return pose
   
class Turtlebot(PbRobot):
    '''
    Wheeled robot loaded with the default PbRobot loader and actuated
    through its two wheel joints (indices 1 and 2).
    '''

    action_space = spaces.Box(low=np.array([-2,-2]),
                              high=np.array([2,2]))

    def __init__(self, bullet_client, physclient, urdf_path, start_pos):
        # Attributes below must exist before PbRobot.__init__, which calls
        # getPose() and therefore getJointStates().
        self.additionalJoints = [1,2] # wheel joints
        # Every joint method addresses the wheels through ``wheelJointIdx``.
        self.wheelJointIdx = self.additionalJoints
        self.jointIdx = self.wheelJointIdx  # kept for backward compatibility
        self.maxJointForce = 100.
        self.baseLink = 2
        super().__init__(bullet_client, physclient, urdf_path, start_pos)
        self.defaultLink = self.baseLink

    @property
    def actions(self):
        '''Number of actuated wheel joints.'''
        return len(self.wheelJointIdx)

    def getPose(self):
        '''
        Return ``(base_pose, wheel_positions)`` where ``base_pose`` is the
        ``(position, orientation)`` pair from PbRobot.getPose and
        ``wheel_positions`` lists the position of each wheel joint.
        '''
        active_joints = self.getJointStates(self.additionalJoints)
        active_joint_vals = [state[0] for state in active_joints]
        return super().getPose(), active_joint_vals

    def getJointStates(self, jointIdx=None):
        '''
        Return pybullet joint states for ``jointIdx`` (defaults to the
        wheel joints).
        '''
        joints = self.wheelJointIdx if jointIdx is None else jointIdx
        return self.bullet_client.getJointStates(self.bodyId,
                                                 jointIndices=joints,
                                                 physicsClientId=self.phys_client)

    def setJoints(self, targetPos):
        '''Drive the wheel joints to ``targetPos`` with position control.'''
        self.bullet_client.setJointMotorControlArray(bodyUniqueId=self.bodyId,
                                                     jointIndices=self.wheelJointIdx,
                                                     controlMode=self.bullet_client.POSITION_CONTROL,
                                                     targetPositions=targetPos,
                                                     targetVelocities=[0.]*len(self.wheelJointIdx),
                                                     forces=[self.maxJointForce]*len(self.wheelJointIdx),
                                                     physicsClientId=self.phys_client)

    def resetJoints(self, joint_states):
        '''
        Reset each wheel joint; ``joint_states[i]`` is ``(position, velocity)``
        for the i-th wheel joint.
        '''
        for i, jointId in enumerate(self.wheelJointIdx):
            self.bullet_client.resetJointState(bodyUniqueId=self.bodyId,
                                               jointIndex=jointId,
                                               targetValue=joint_states[i][0],
                                               targetVelocity=joint_states[i][1],
                                               physicsClientId=self.phys_client)

    def resetFullPose(self, translation, rotation, joints):
        '''
        Teleport the base, then set the wheel joints to ``joints`` with zero velocity.
        '''
        self.resetPose(translation, rotation)
        joint_states = [[value, 0] for value in joints]
        self.resetJoints(joint_states=joint_states)


# class SlabRobo(PbRobot):

#     action_space = spaces.Box(low=np.array([-0.1,-0.1,-0.314]),
#                               high=np.array([0.1,0.1,0.314]))

#     def __init__(self, bullet_client, physclient, urdf_path, start_pos):
        
#         self.maxJointForce = 1.
#         self.basejoints = [0,1,2] # phantom joints
#         self.additionalJoints = []
#         self.JointIdx = self.basejoints + self.additionalJoints
#         self.baseLink = 4
#         super().__init__(bullet_client, physclient, urdf_path, start_pos, robot_loader=self.load_robot)
#         self.defaultLink = self.baseLink
        
#     @property
#     def actions(self):
#         return len(self.JointIdx)

#     def load_robot(self, urdf_path, pose):
#         # Load the base at 0,0,0
#         # Movement happens through pose resetting via joint states
#         # for this robot
#         return self.bullet_client.loadURDF(fileName=urdf_path,
#                                     basePosition=[0.,0.,0.], 
#                                     baseOrientation=[0.,0.,0.,1.],
#                                     physicsClientId=self.phys_client)

#     def getPose(self):
#         link_state = self.bullet_client.getLinkState(self.bodyId, self.baseLink)
#         active_joint_vals = [y[0] for y in self.bullet_client.getJointStates(self.bodyId,[0,1,2])]
#         return (link_state[0], link_state[1]),active_joint_vals
    
#     def resetPose(self, pose, z):

#         # rotation_euler = self.bullet_client.getEulerFromQuaternion(rotation)

#         self.bullet_client.resetJointState(self.bodyId, 0, pose[0]) # X
#         self.bullet_client.resetJointState(self.bodyId, 1, pose[1]) # Y
#         self.bullet_client.resetJointState(self.bodyId, 3, z) # Z
#         self.bullet_client.resetJointState(self.bodyId, 2, pose[2]) # Yaw

#         # self.bullet_client.resetBasePositionAndOrientation(bodyUniqueId=self.bodyId,
#         #                                         posObj=[0.,0.,z], # only change the z
#         #                                         ornObj=[0.,0.,0.,1.],
#         #                                         physicsClientId=self.phys_client)
    
#     def resetFullPose(self, pose,z):

#         self.bullet_client.resetJointState(self.bodyId, 0, pose[0]) # X
#         self.bullet_client.resetJointState(self.bodyId, 1, pose[1]) # Y
#         self.bullet_client.resetJointState(self.bodyId, 3, z) # Z
#         self.bullet_client.resetJointState(self.bodyId, 2, pose[2]) # Yaw

#         # self.bullet_client.resetBasePositionAndOrientation(bodyUniqueId=self.bodyId,
#         #                                         posObj=[0.,0.,z], # only change the z
#         #                                         ornObj=[0.,0.,0.,1.],
#         #                                         physicsClientId=self.phys_client)

#     def getJointStates(self):
#         return self.bullet_client.getJointStates(self.bodyId,
#                                            jointIndices=self.JointIdx)
    
#     def setJoints(self, targetPos):
#         self.bullet_client.setJointMotorControlArray(bodyUniqueId=self.bodyId,
#                                                     jointIndices=self.JointIdx,
#                                                     controlMode=self.bullet_client.POSITION_CONTROL,
#                                                     targetPositions=targetPos,
#                                                     forces=[self.maxJointForce]*len(self.JointIdx),
#                                                     physicsClientId=self.phys_client)
    
#     def resetJoints(self, joint_states):
#         for i, jointId in enumerate(self.JointIdx):
#             self.bullet_client.resetJointState(bodyUniqueId=self.bodyId,
#                                                jointIndex=jointId,
#                                                targetValue=joint_states[i][0],
#                                                targetVelocity=joint_states[i][1],
#                                                physicsClientId=self.phys_client)

class HingedRobo(PbRobot):
    '''
    Two-link robot moved through phantom joints plus a hinge.

    Joint layout used by this class: 0 = X, 1 = Y, 2 = Yaw, 3 = Z,
    5 = hinge between the two links. Links 4 and 5 are the contact links.
    '''

    action_space = spaces.Box(low=np.array([-0.1,-0.1,-0.314,-0.157]),
                              high=np.array([0.1,0.1,0.314,0.157]))

    def __init__(self, bullet_client, physclient, urdf_path, start_pos):
        # These must exist before PbRobot.__init__, which calls getPose().
        self.maxJointForce = 100.
        self.basejoints = [0,1,2] # Phantom joints (X, Y, Yaw)
        self.additionalJoints = [5] # Hinge joint between the two links
        self.JointIdx = self.basejoints+self.additionalJoints
        self.baseLink = [4,5] # Two contact links that are used to check collision with environment
        super().__init__(bullet_client, physclient, urdf_path, start_pos, robot_loader=self.load_robot)
        self.defaultLink = self.baseLink

    @property
    def actions(self):
        '''Number of controllable joints (phantom base joints + hinge).'''
        return len(self.JointIdx)

    def load_robot(self, urdf_path, pose):
        '''
        Load the URDF at the world origin with identity orientation.

        ``pose`` is ignored; the robot is moved afterwards by resetting its
        phantom joints.
        '''
        return self.bullet_client.loadURDF(fileName=urdf_path,
                                           basePosition=[0.,0.,0.],
                                           baseOrientation=[0.,0.,0.,1.],
                                           physicsClientId=self.phys_client)

    def getPose(self):
        '''
        Return ``((link_pos, link_orn), joint_values)``.

        ``link_pos``/``link_orn`` come from getLinkState of the first contact
        link; ``joint_values`` are the positions of joints 0-3 (X, Y, Yaw, Z).
        NOTE(review): this reads Z (joint 3) rather than the hinge (joint 5)
        that JointIdx uses — confirm which is intended.
        '''
        link_state = self.bullet_client.getLinkState(self.bodyId, self.baseLink[0],
                                                     physicsClientId=self.phys_client)
        active_joints = self.getJointStates([0,1,2,3])
        active_joint_vals = [state[0] for state in active_joints]

        return (link_state[0], link_state[1]), active_joint_vals

    def resetPose(self, translation, rotation):
        '''
        Move the robot by resetting its phantom joints.

        translation: ``[x, y, z]``; rotation: quaternion, of which only the
        yaw component is applied.
        '''
        rotation_euler = self.bullet_client.getEulerFromQuaternion(rotation)

        self.bullet_client.resetJointState(self.bodyId, 0, translation[0], physicsClientId=self.phys_client) # X
        self.bullet_client.resetJointState(self.bodyId, 1, translation[1], physicsClientId=self.phys_client) # Y
        self.bullet_client.resetJointState(self.bodyId, 3, translation[2], physicsClientId=self.phys_client) # Z
        self.bullet_client.resetJointState(self.bodyId, 2, rotation_euler[2], physicsClientId=self.phys_client) # Yaw

    def resetFullPose(self, joints, z):
        '''
        Reset every joint of the robot.

        joints: ``[x, y, yaw, hinge...]`` — one value per entry of
            ``basejoints`` followed by one value per entry of ``additionalJoints``.
        z: value for the Z joint (3).
        '''
        self.bullet_client.resetJointState(self.bodyId, 0, joints[0], physicsClientId=self.phys_client) # X
        self.bullet_client.resetJointState(self.bodyId, 1, joints[1], physicsClientId=self.phys_client) # Y
        self.bullet_client.resetJointState(self.bodyId, 3, z, physicsClientId=self.phys_client) # Z
        self.bullet_client.resetJointState(self.bodyId, 2, joints[2], physicsClientId=self.phys_client) # Yaw

        # Hinge joint(s): take the values that follow the base-joint values so
        # each hinge gets its own value (not joints[0]), with zero velocity.
        hinge_states = [[value, 0] for value in joints[len(self.basejoints):]]
        self.resetJoints(joint_states=hinge_states, jointIdx=self.additionalJoints)

    def getJointStates(self,jointIdx=None):
        '''Return pybullet joint states for ``jointIdx`` (defaults to JointIdx).'''
        joints = self.JointIdx if jointIdx is None else jointIdx
        return self.bullet_client.getJointStates(self.bodyId,
                                                 jointIndices=joints,
                                                 physicsClientId=self.phys_client)

    def setJoints(self, targetPos,jointIdx=None):
        '''
        Drive ``jointIdx`` (defaults to JointIdx) to ``targetPos`` with
        position control.
        '''
        joints = self.JointIdx if jointIdx is None else jointIdx
        self.bullet_client.setJointMotorControlArray(bodyUniqueId=self.bodyId,
                                                     jointIndices=joints,
                                                     controlMode=self.bullet_client.POSITION_CONTROL,
                                                     targetPositions=targetPos,
                                                     forces=[self.maxJointForce]*len(joints),
                                                     physicsClientId=self.phys_client)

    def resetJoints(self, joint_states, jointIdx=None):
        '''
        Reset ``jointIdx`` (defaults to JointIdx); ``joint_states[i]`` is
        ``(position, velocity)`` for the i-th index.
        '''
        joints = self.JointIdx if jointIdx is None else jointIdx
        for i, jointId in enumerate(joints):
            self.bullet_client.resetJointState(bodyUniqueId=self.bodyId,
                                               jointIndex=jointId,
                                               targetValue=joint_states[i][0],
                                               targetVelocity=joint_states[i][1],
                                               physicsClientId=self.phys_client)

# Maps robot-type names to their wrapper classes.
# "slab" and "puck" are aliases for the same SlabRobo wrapper.
robotsDict = dict(
    husky=Husky,
    turtlebot=Turtlebot,
    slab=SlabRobo,
    puck=SlabRobo,
    hinged=HingedRobo,
    limo=LimoRobot,
)